# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import glob
import os
import subprocess

import torch
from setuptools import find_packages, setup
from torch.utils.cpp_extension import (
    CppExtension,
    CUDAExtension,
    BuildExtension,
    CUDA_HOME,
)

# Name used for the distribution, for the package directory that holds the
# C++/CUDA sources (<library_name>/csrc), and as the prefix of the compiled
# extension module (<library_name>._C).
library_name = "MinkowskiEngineNova"


def get_extensions():
    """Describe the ``<library_name>._C`` extension module to build.

    Behaviour is controlled by environment variables:

    * ``DEBUG=1``  -- compile and link with ``-O0 -g`` instead of ``-O3``.
    * ``USE_CUDA`` -- defaults to ``"1"``; any other value builds CPU-only.
      CUDA is also disabled when ``torch.cuda.is_available()`` is False or
      ``CUDA_HOME`` is not set, so a CUDA build needs a visible GPU at build
      time.

    Sources are ``<library_name>/csrc/*.cpp`` plus, for CUDA builds,
    ``<library_name>/csrc/cuda/*.cu``. The paths are relative to the current
    working directory, so setup.py must be run from the project root.

    Returns:
        A one-element list holding the configured ``CppExtension`` or
        ``CUDAExtension``.

    Raises:
        RuntimeError: if no source files are found (typically because
            setup.py was run from a directory other than the project root).
    """
    debug_mode = os.getenv("DEBUG", "0") == "1"
    use_cuda = os.getenv("USE_CUDA", "1") == "1"
    if debug_mode:
        print("Compiling in debug mode")

    use_cuda = use_cuda and torch.cuda.is_available() and CUDA_HOME is not None
    extension = CUDAExtension if use_cuda else CppExtension

    # Single switch for the stable-ABI build: the Py_LIMITED_API macro and the
    # extension's py_limited_api flag must agree. Defining the macro alone
    # restricts the module to the limited API without tagging it as such.
    py_limited_api = False

    opt_level = "-O0" if debug_mode else "-O3"
    extra_link_args = []
    extra_compile_args = {
        "cxx": [opt_level, "-fdiagnostics-color=always"],
        "nvcc": [opt_level],
    }
    if py_limited_api:
        # min CPython version 3.9
        extra_compile_args["cxx"].append("-DPy_LIMITED_API=0x03090000")
    if debug_mode:
        extra_compile_args["cxx"].append("-g")
        extra_compile_args["nvcc"].append("-g")
        extra_link_args.extend(["-O0", "-g"])

    # Equivalent to the template's os.path.join(os.path.dirname(os.path.curdir), ...):
    # dirname(".") is "", so the path is relative to the current directory.
    extensions_dir = os.path.join(library_name, "csrc")
    # Sorted so the source order (and the generated build files) is stable.
    sources = sorted(glob.glob(os.path.join(extensions_dir, "*.cpp")))
    if use_cuda:
        extensions_cuda_dir = os.path.join(extensions_dir, "cuda")
        sources += sorted(glob.glob(os.path.join(extensions_cuda_dir, "*.cu")))

    if not sources:
        raise RuntimeError(
            f"No C++/CUDA sources found in {os.path.abspath(extensions_dir)!r}; "
            "run setup.py from the project root."
        )

    return [
        extension(
            f"{library_name}._C",
            sources,
            extra_compile_args=extra_compile_args,
            extra_link_args=extra_link_args,
            py_limited_api=py_limited_api,
        )
    ]


class CustomBuildExt(BuildExtension):
    """``BuildExtension`` variant that builds into fixed directories.

    Intermediate files go to ``./build/temp`` and built libraries to
    ``./build/lib``. Both paths are resolved against the current working
    directory at build time. ``./build/temp`` matches the default
    ``build_path`` that ``create_compile_commands`` searches for
    ``build.ninja``.
    """

    def build_extensions(self):
        # Resolve the fixed, cwd-relative output locations to absolute paths
        # and install them before the parent class starts compiling.
        temp_dir, lib_dir = (
            os.path.abspath(rel) for rel in ("./build/temp", "./build/lib")
        )
        # NOTE(review): this unconditionally overwrites whatever build_temp /
        # build_lib were configured (e.g. via --build-temp / --build-lib).
        # For non-inplace builds the compiled module presumably lands under
        # ./build/lib rather than the `build` command's build_lib; confirm
        # that wheel / non-editable installs still pick up the extension.
        self.build_temp = temp_dir
        self.build_lib = lib_dir

        print(f"[INFO] Using build_temp: {self.build_temp}")
        print(f"[INFO] Using build_lib: {self.build_lib}")

        super().build_extensions()


def create_compile_commands(
    build_path="build/temp",
    output_path=os.path.join(".", "compile_commands.json"),
):
    """Export a ``compile_commands.json`` from the ninja build files.

    Searches ``build_path`` recursively for ``build.ninja`` and runs
    ``ninja -f <file> -t compdb``, writing its stdout to ``output_path``.

    This is a developer convenience, so problems are reported and skipped
    instead of raised. A missing ``build.ninja`` (e.g. the setup command did
    not compile anything) or a missing or failing ``ninja`` executable leaves
    any existing ``output_path`` untouched.

    Args:
        build_path: Directory searched recursively for ``build.ninja``.
        output_path: File that receives the compilation database.

    Returns:
        ``output_path`` if it was written, otherwise ``None``.
    """
    build_path = os.path.realpath(build_path)
    ninja_files = [
        os.path.join(root, "build.ninja")
        for root, _dirs, files in os.walk(build_path)
        if "build.ninja" in files
    ]
    if not ninja_files:
        print(f"[INFO] No build.ninja under {build_path}; skipping compile_commands.json")
        return None

    # More than one build.ninja means stale build trees were left behind;
    # the most recently written file belongs to the latest build.
    ninja_file = os.path.realpath(max(ninja_files, key=os.path.getmtime))
    if len(ninja_files) > 1:
        print(f"[WARN] Found {len(ninja_files)} build.ninja files; using {ninja_file}")

    # Argument list (no shell) so paths with spaces/metacharacters are safe,
    # and check=True so a failing ninja never clobbers a good output file.
    try:
        result = subprocess.run(
            ["ninja", "-f", ninja_file, "-t", "compdb"],
            capture_output=True,
            text=True,
            check=True,
        )
    except (OSError, subprocess.CalledProcessError) as err:
        print(f"[WARN] Could not generate compile_commands.json: {err}")
        return None

    with open(output_path, "w", encoding="utf-8") as f:
        f.write(result.stdout)
    print(f"Saved compile_commands.json at {os.path.relpath(output_path)}")
    return output_path


# NOTE(review): description and url below still point at the PyTorch
# extension-cpp template; confirm/update them for this project.
with open("README.md", encoding="utf-8") as readme_file:
    long_description = readme_file.read()

dist = setup(
    name=library_name,
    version="0.0.1",
    packages=find_packages(),
    ext_modules=get_extensions(),
    install_requires=["torch"],
    description="Example of PyTorch C++ and CUDA extensions",
    long_description=long_description,
    long_description_content_type="text/markdown",
    url="https://github.com/pytorch/extension-cpp",
    cmdclass={"build_ext": CustomBuildExt.with_options(use_ninja=True)},
    options={},
)

# Only export compile_commands.json when this invocation actually ran
# build_ext; metadata-only commands such as egg_info or sdist compile nothing
# and therefore produce no build.ninja to read.
if dist.have_run.get("build_ext"):
    create_compile_commands()
